Transformer的其它组件

Note

下图是transformer的结构图,我们已经描述并实现了其中的多头注意力(Multi-head attention)和位置编码(Positional encoding)。
本节我们来讲tranformer的另外的组件:基于位置的前馈网络(Positionwise FFN)、残差连接和层归一化(Add & norm)。

jupyter

基于位置的前馈网络

即对序列中所有位置的表示进行变换时,使用的是同一个多层感知机(MLP)。

import torch
from torch import nn


#@save
class PositionWiseFFN(nn.Module):
    """基于位置的前馈网络"""
    def __init__(self, ffn_num_input, ffn_num_hiddens):
        super(PositionWiseFFN, self).__init__()
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_input)

    def forward(self, X):
        # X shape: (`batch_size`, `num_steps`, `ffn_num_input`)
        # 输入和输出的形状一样
        return self.dense2(self.relu(self.dense1(X)))

残差连接和层归一化

此组件由残差连接和紧随其后的层归一化组成,两者都是构建有效的深度结构的关键。

层归一化和批量归一化(Batch Normalization)的目标相同,但层归一化的均值和方差在最后几个维度上进行计算。在自然语言处理任务中批量归一化通常不如层归一化效果好。

ln = nn.LayerNorm(3)
bn = nn.BatchNorm1d(3)
X = torch.tensor([[1, 2, 3], [4, 6, 8]], dtype=torch.float32)
# 层归一化
ln(X), (X - X.mean(axis=1).reshape(-1, 1)) / X.std(axis=1, unbiased=False).reshape(-1, 1)
(tensor([[-1.2247,  0.0000,  1.2247],
         [-1.2247,  0.0000,  1.2247]], grad_fn=<NativeLayerNormBackward>),
 tensor([[-1.2247,  0.0000,  1.2247],
         [-1.2247,  0.0000,  1.2247]]))
# 批量归一化
bn(X), (X - X.mean(axis=0)) / X.std(axis=0, unbiased=False)
(tensor([[-1.0000, -1.0000, -1.0000],
         [ 1.0000,  1.0000,  1.0000]], grad_fn=<NativeBatchNormBackward>),
 tensor([[-1., -1., -1.],
         [ 1.,  1.,  1.]]))
#@save
class AddNorm(nn.Module):
    """残差连接和层归一化"""
    def __init__(self, normalized_shape, dropout):
        super(AddNorm, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # normalized_shape指定均值和方差计算的维度,需是后几个维度
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        # 先残差连接,再层归一化
        return self.ln(self.dropout(Y) + X)
add_norm = AddNorm([3, 4], 0.5)
add_norm.eval()
# 形状不变
add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape
torch.Size([2, 3, 4])